# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import doctest
import os
import sys
from enum import Enum
from typing import Dict, Iterable, List, TextIO


class Target(Enum):
    """Renderings that a type description can be formatted into.

    NAME   -- identifier fragment used for generated class and field names.
    THRIFT -- Thrift IDL syntax.
    CPP2   -- C++ type expression used in the generated cpp2 header.
    """

    # Enum assigns each unpacked name individually: NAME=0, THRIFT=1, CPP2=2.
    NAME, THRIFT, CPP2 = range(3)


# Preamble written at the top of the generated .thrift file.
# The "@" is spliced in via an f-string expression, so the literal
# "@generated" marker never appears in this generator's own source
# (presumably so tooling does not treat generator.py itself as generated).
THRIFT_HEADER = f"""# This file was generated by `thrift/test/testset/generator.py`
# {'@'}generated

namespace cpp2 apache.thrift.test.testset
"""

# Preamble of the generated C++ header. It declares:
# - the FieldModifier enum;
# - detail::mod_set, a fatal::sort of a fatal::sequence of modifiers
#   (presumably so that modifier order does not affect lookups);
# - the primary *_ByFieldType templates that print_cpp2_specialization
#   specializes.
# Literal braces are doubled ({{ / }}) because this is an f-string.
CPP2_HEADER = f"""// This file was generated by `thrift/test/testset/generator.py`
// {'@'}generated

#pragma once

#include <fatal/type/sequence.h>
#include <fatal/type/sort.h>
#include <thrift/conformance/cpp2/ThriftTypes.h>
#include <thrift/test/testset/gen-cpp2/testset_types.h>

namespace apache::thrift::test::testset {{

enum class FieldModifier {{
  Optional = 1,
  Required,
  Reference,
  Lazy,
}};

namespace detail {{

template <FieldModifier... Ms>
using mod_set = fatal::sort<fatal::sequence<FieldModifier, Ms...>>;

template <typename T, typename Ms>
struct struct_ByFieldType;

template <typename T, typename Ms>
struct exception_ByFieldType;

template <typename T, typename Ms>
struct union_ByFieldType;
"""

# Tail of the generated C++ header. It closes `detail` and exposes the public
# struct_with / exception_with / union_with aliases, which resolve
# (type tag, modifiers...) to the generated class via the specializations.
# Not an f-string, so braces here are single.
CPP2_FOOTER = """
} // namespace detail

template <typename T, FieldModifier... Ms>
using struct_with = typename detail::struct_ByFieldType<T, detail::mod_set<Ms...>>::type;

template <typename T, FieldModifier... Ms>
using exception_with = typename detail::exception_ByFieldType<T, detail::mod_set<Ms...>>::type;

template <typename T, FieldModifier... Ms>
using union_with = typename detail::union_ByFieldType<T, detail::mod_set<Ms...>>::type;

} // namespace apache::thrift::test::testset
"""

# Thrift base types from which every generated field set is built.
PRIMITIVE_TYPES = (
    "bool", "byte",
    "i16", "i32", "i64",
    "float", "double",
    "binary", "string",
)

# Types used as set elements and as map keys.
KEY_TYPES = ("string", "i64")

# C++ namespace of the conformance type tags. The generated header refers to
# it from inside apache::thrift::test::testset.
CPP2_TYPE_NS = "conformance::type"

# Transform tables. Each maps a Target to a str.format template:
#   Target.NAME   -> identifier fragment (used for generated class/field names)
#   Target.THRIFT -> Thrift IDL spelling
#   Target.CPP2   -> C++ type expression for the cpp2 header
# In the field-modifier tables, "|" separates the base type from trailing
# extras. gen_thrift_def turns them into Thrift annotations, and
# print_cpp2_specialization turns them into mod_set<> arguments.
# (The PRIMATIVE spelling is kept as-is because other code refers to the name.)
PRIMATIVE_TRANSFORM: Dict[Target, str] = {
    Target.NAME: "{}",
    Target.THRIFT: "{}",
    Target.CPP2: CPP2_TYPE_NS + "::{}_t",
}

STRUCT_TRANSFORM: Dict[Target, str] = {
    Target.NAME: "struct_{}",
    Target.THRIFT: "struct {}",
    Target.CPP2: CPP2_TYPE_NS + "::struct_t<{}>",
}

UNION_TRANSFORM: Dict[Target, str] = {
    Target.NAME: "union_{}",
    Target.THRIFT: "union {}",
    Target.CPP2: CPP2_TYPE_NS + "::union_t<{}>",
}

EXCEPTION_TRANSFORM: Dict[Target, str] = {
    Target.NAME: "exception_{}",
    Target.THRIFT: "exception {}",
    Target.CPP2: CPP2_TYPE_NS + "::exception_t<{}>",
}

# NOTE(review): unlike the tags above, the container tags carry no "_t"
# suffix; presumably this mirrors ThriftTypes.h -- confirm there.
LIST_TRANSFORM: Dict[Target, str] = {
    Target.NAME: "list_{}",
    Target.THRIFT: "list<{}>",
    Target.CPP2: CPP2_TYPE_NS + "::list<{}>",
}

SET_TRANSFORM: Dict[Target, str] = {
    Target.NAME: "set_{}",
    Target.THRIFT: "set<{}>",
    Target.CPP2: CPP2_TYPE_NS + "::set<{}>",
}

# Two-slot templates: key first, then value.
MAP_TRANSFORM: Dict[Target, str] = {
    Target.NAME: "map_{}_{}",
    Target.THRIFT: "map<{}, {}>",
    Target.CPP2: CPP2_TYPE_NS + "::map<{}, {}>",
}

# Field-modifier tables. The CPP2 column appends a FieldModifier after "|".
OPTIONAL_TRANSFORM: Dict[Target, str] = {
    Target.NAME: "optional_{}",
    Target.THRIFT: "optional {}",
    Target.CPP2: "{}|FieldModifier::Optional",
}

REQUIRED_TRANSFORM: Dict[Target, str] = {
    Target.NAME: "required_{}",
    Target.THRIFT: "required {}",
    Target.CPP2: "{}|FieldModifier::Required",
}

# cpp.ref and lazy are Thrift annotations, so both columns append after "|".
CPP_REF_TRANSFORM: Dict[Target, str] = {
    Target.NAME: "{}_cpp_ref",
    Target.THRIFT: "{}|cpp.ref",
    Target.CPP2: "{}|FieldModifier::Reference",
}

LAZY_TRANSFORM: Dict[Target, str] = {
    Target.NAME: "{}_lazy",
    Target.THRIFT: "{}|cpp.experimental.lazy",
    Target.CPP2: "{}|FieldModifier::Lazy",
}


def gen_primatives(
    target: Target, prims: Iterable[str] = PRIMITIVE_TYPES
) -> Dict[str, str]:
    """Map each primitive's name to its rendering for ``target``.

    Args:
        target: which rendering (NAME / THRIFT / CPP2) to produce.
        prims: Thrift primitive type names to include.

    Returns:
        A new dict from the primitive's NAME form to its ``target`` form.
    """
    name_fmt = PRIMATIVE_TRANSFORM[Target.NAME]
    return {
        name_fmt.format(prim_name): PRIMATIVE_TRANSFORM[target].format(prim_name)
        for prim_name in prims
    }


def _gen_unary_tramsform(
    transform: Dict[Target, str], target: Target, values: Dict[str, str]
) -> Dict[str, str]:
    """Apply a one-slot transform table to every (name, type) entry.

    Each key is renamed with ``transform[Target.NAME]``, and each type is
    rendered with ``transform[target]``. Input order is preserved.
    """
    return {
        transform[Target.NAME].format(entry_name): transform[target].format(entry_type)
        for entry_name, entry_type in values.items()
    }


def gen_lists(target: Target, values: Dict[str, str]) -> Dict[str, str]:
    """Wrap every entry of ``values`` in a list type (keys become ``list_<name>``)."""
    return _gen_unary_tramsform(
        transform=LIST_TRANSFORM, target=target, values=values
    )


def gen_sets(target: Target, values: Dict[str, str]) -> Dict[str, str]:
    """Wrap every entry of ``values`` in a set type (keys become ``set_<name>``)."""
    return _gen_unary_tramsform(
        transform=SET_TRANSFORM, target=target, values=values
    )


def gen_maps(
    target: Target, keys: Dict[str, str], values: Dict[str, str]
) -> Dict[str, str]:
    """Build a map type for every (key, value) combination.

    Entries are named ``map_<key>_<value>`` and ordered key-major: all values
    for the first key come first, then the next key, and so on.
    """
    name_fmt = MAP_TRANSFORM[Target.NAME]
    combos: Dict[str, str] = {}
    for k_name, k_type in keys.items():
        combos.update(
            {
                name_fmt.format(k_name, v_name): MAP_TRANSFORM[target].format(k_type, v_type)
                for v_name, v_type in values.items()
            }
        )
    return combos


def gen_optional(target: Target, values: Dict[str, str]) -> Dict[str, str]:
    """Mark every entry of ``values`` optional (keys become ``optional_<name>``)."""
    return _gen_unary_tramsform(
        transform=OPTIONAL_TRANSFORM, target=target, values=values
    )


def gen_required(target: Target, values: Dict[str, str]) -> Dict[str, str]:
    """Mark every entry of ``values`` required (keys become ``required_<name>``)."""
    return _gen_unary_tramsform(
        transform=REQUIRED_TRANSFORM, target=target, values=values
    )


def gen_cpp_ref(target: Target, values: Dict[str, str]) -> Dict[str, str]:
    """Add the cpp.ref variant of every entry (keys become ``<name>_cpp_ref``)."""
    return _gen_unary_tramsform(
        transform=CPP_REF_TRANSFORM, target=target, values=values
    )


def gen_lazy(target: Target, values: Dict[str, str]) -> Dict[str, str]:
    """Add the lazy variant of every entry (keys become ``<name>_lazy``)."""
    return _gen_unary_tramsform(
        transform=LAZY_TRANSFORM, target=target, values=values
    )


def gen_container_fields(target: Target) -> Dict[str, str]:
    """Generate container-typed field name -> type pairs.

    The result contains, in this order:
    - a list of every primitive;
    - a set of every key type;
    - a map from every key type to every primitive;
    - a map from every key type to every set type.

    It is the shared base for the union and lazy field sets.
    """
    primitives = gen_primatives(target, PRIMITIVE_TYPES)
    key_types = gen_primatives(target, KEY_TYPES)
    key_sets = gen_sets(target, key_types)

    fields: Dict[str, str] = {}
    for group in (
        gen_lists(target, primitives),
        key_sets,
        gen_maps(target, key_types, primitives),
        gen_maps(target, key_types, key_sets),
    ):
        fields.update(group)
    return fields


def gen_union_fields(target: Target) -> Dict[str, str]:
    """Generate field name -> type pairs usable as union members.

    The result holds every container field, a cpp.ref variant of each
    container field, and every primitive, in that order.
    """
    containers = gen_container_fields(target)
    return {
        **containers,
        **gen_cpp_ref(target, containers),
        **gen_primatives(target, PRIMITIVE_TYPES),
    }


def gen_lazy_fields(target: Target) -> Dict[str, str]:
    """Generate lazy-annotated fields: every container type plus ``string``."""
    candidates = {
        **gen_container_fields(target),
        **gen_primatives(target, ("string",)),
    }
    return gen_lazy(target, candidates)


def gen_struct_fields(target: Target) -> Dict[str, str]:
    """Generate field name -> type pairs appropriate for use in structs.

    Starts from the union fields (containers, their cpp.ref variants, and
    primitives). It then adds an ``optional_`` and a ``required_`` variant of
    each, and finally the lazy fields.

    Args:
        target: which rendering (NAME / THRIFT / CPP2) to produce.

    Returns:
        A new dict mapping generated field names to rendered types.
    """
    fields = gen_union_fields(target)
    # Derive both qualified variants from the same unqualified base, so that
    # neither qualifier is stacked on top of the other.
    optional = gen_optional(target, fields)
    required = gen_required(target, fields)
    # Plain dict.update instead of update(**a, **b): no round-trip through
    # keyword arguments, which would raise TypeError on a key shared by both.
    fields.update(optional)
    fields.update(required)
    fields.update(gen_lazy_fields(target))
    return fields


def gen_thrift_def(
    transform: Dict[Target, str], name: str, field_types: List[str]
) -> str:
    """Render a Thrift definition named ``name`` with one field per type.

    Field ``i`` (1-based) is called ``field_i``. Anything after a ``|`` in a
    field type is emitted as a parenthesised, comma-separated annotation
    list. The definition is tagged with a ``thrift.uri`` derived from ``name``.

    >>> print(gen_thrift_def(STRUCT_TRANSFORM, "Foo", ["i64", "optional string", "set<i32>|cpp.ref"]))
    struct Foo {
      1: i64 field_1;
      2: optional string field_2;
      3: set<i32> field_3 (cpp.ref);
    } (thrift.uri="facebook.com/thrift/test/testset/Foo")
    """
    opening = transform[Target.THRIFT].format(name) + " {"
    closing = f'}} (thrift.uri="facebook.com/thrift/test/testset/{name}")'
    rendered = [opening]
    for field_id, spec in enumerate(field_types, start=1):
        base_type, *annotations = spec.split("|")
        suffix = f" ({', '.join(annotations)})" if annotations else ""
        rendered.append(f"  {field_id}: {base_type} field_{field_id}{suffix};")
    rendered.append(closing)
    return "\n".join(rendered)


def print_thrift_defs(
    transform: Dict[Target, str],
    fields: Dict[str, str],
    count: int = 1,
    *,
    file: TextIO = sys.stdout,
) -> List[str]:
    """Print an empty class plus one class per entry in ``fields``.

    Each generated class holds ``count`` fields of the entry's type.

    Returns:
        The names of every printed class, in print order (empty one first).
    """
    class_names: List[str] = []

    def emit(class_name: str, field_types: List[str]) -> None:
        # Write one definition and record its name for the caller.
        print(gen_thrift_def(transform, class_name, field_types), file=file)
        class_names.append(class_name)

    emit(transform[Target.NAME].format("empty"), [])
    for field_name, field_type in fields.items():
        emit(transform[Target.NAME].format(field_name), [field_type] * count)
    return class_names


def gen_thrift(path: str) -> None:
    """Write the Thrift IDL for the test set to ``path``.

    The file contains:
    - every struct;
    - an exception mirroring each struct;
    - every union, with two fields each;
    - a final ``struct_all`` with one field per struct and union.
    """
    with open(path, "w") as out:
        print(THRIFT_HEADER, file=out)

        struct_fields = gen_struct_fields(Target.THRIFT)
        struct_names = print_thrift_defs(STRUCT_TRANSFORM, struct_fields, file=out)
        # Exceptions reuse the struct fields; their names are not included
        # in the aggregate struct below.
        print_thrift_defs(EXCEPTION_TRANSFORM, struct_fields, file=out)

        union_fields = gen_union_fields(Target.THRIFT)
        union_names = print_thrift_defs(
            UNION_TRANSFORM, union_fields, count=2, file=out
        )

        aggregate_name = STRUCT_TRANSFORM[Target.NAME].format("all")
        aggregate = gen_thrift_def(
            STRUCT_TRANSFORM, aggregate_name, struct_names + union_names
        )
        print(aggregate, file=out)


# Explicit specialization emitted once per generated class (filled with
# str.format, hence the doubled braces). The positional slots are:
# 1. the lookup template name (e.g. struct_ByFieldType);
# 2. the field's C++ type tag;
# 3. the comma-separated FieldModifier list;
# 4. the generated class name exposed as ::type.
CPP2_SPECIALIZE_TEMPLATE = """template <>
struct {}<{}, mod_set<{}>> {{
  using type = {};
}};
"""


def print_cpp2_specialization(
    transform: Dict[Target, str], fields: Dict[str, str], *, file: TextIO = sys.stdout
) -> None:
    """Print one ``<kind>_ByFieldType`` specialization per entry of ``fields``.

    Each value in ``fields`` is a C++ type tag, optionally followed by
    ``|``-separated FieldModifier names (the CPP2 column of the modifier
    transform tables). The specialization maps that (type, modifier set)
    pair to the generated class named ``transform[Target.NAME].format(key)``.

    Args:
        transform: STRUCT/EXCEPTION/UNION transform selecting the class kind.
        fields: generated field name -> ``"type|Modifier|..."`` string.
        file: stream to write to.
    """
    name_fmt = transform[Target.NAME]
    # The lookup template name depends only on the class kind, so compute it
    # once instead of once per field.
    by_type = name_fmt.format("ByFieldType")
    for field, value_mods in fields.items():
        value_t, *mods = value_mods.split("|")
        print(
            CPP2_SPECIALIZE_TEMPLATE.format(
                by_type, value_t, ", ".join(mods), name_fmt.format(field)
            ),
            file=file,
        )


def gen_cpp2(path: str) -> None:
    """Write the cpp2 header that maps (type tag, modifiers) to generated classes.

    Struct and exception specializations share the struct field set. Union
    specializations use the union field set.
    """
    with open(path, "w") as header:
        print(CPP2_HEADER, file=header)

        struct_fields = gen_struct_fields(Target.CPP2)
        for kind_transform in (STRUCT_TRANSFORM, EXCEPTION_TRANSFORM):
            print_cpp2_specialization(kind_transform, struct_fields, file=header)

        print_cpp2_specialization(
            UNION_TRANSFORM, gen_union_fields(Target.CPP2), file=header
        )

        print(CPP2_FOOTER, file=header)


def generate(dir: str) -> None:
    """Write ``testset.thrift`` and then ``Testset.h`` into directory ``dir``."""
    # (parameter keeps its original name for callers passing dir=...)
    outputs = (
        (gen_thrift, "testset.thrift"),
        (gen_cpp2, "Testset.h"),
    )
    for writer, filename in outputs:
        writer(os.path.join(dir, filename))


def main() -> None:
    """Self-test via doctest, then generate the test set into --install_dir.

    Exits with a non-zero status, before writing anything, if any doctest in
    the __main__ module fails.
    """
    # doctest.testmod() only reports failures and returns a count. Check it,
    # so a broken generator cannot silently emit output and exit 0.
    results = doctest.testmod()
    if results.failed:
        sys.exit(f"generator.py: {results.failed} doctest(s) failed")
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--install_dir",
        required=True,
        help="directory that receives testset.thrift and Testset.h",
    )
    args = parser.parse_args()
    generate(args.install_dir)


if __name__ == "__main__":
    main()
